#include "StdAfx.h"
#include "iDIDSolver.h"
#include "ExpandTimeSteps.h"
#include <iosfwd>
#include <stdexcept>

//////////////////////////////////////////////////////////////////////////
//Author	:	Ross Conroy ross.conroy@tees.ac.uk
//Date		:	21/05/2014
//
//This class is intended to solve i-DIDs (Interactive Dynamic Influence
//Diagrams). This is achieved with the following steps:
//Algorithm
//1.Expand the model DIDs and the i-DID to the same number of time steps
//2.Solve the level-0 DIDs
//3.Apply the policies from the level-0 DIDs to the i-DID
//4.Solve the i-DID by generating its policy with the DID information added
//////////////////////////////////////////////////////////////////////////

//Default constructor. The body is intentionally empty.
iDIDSolver::iDIDSolver()
{
}

//Solves an i-DID by following the algorithm in the header comment:
//1. expand the model DIDs and the i-DID by numExpansions time steps,
//2. compile each model and compute its policies (saved as \nets\Model<i>.net),
//3. copy the model tables into the i-DID (saved as \nets\outputiDID.net),
//4. compile the i-DID and compute its policies (saved as \nets\iDIDOutput.net).
//
//inputIDID     : the i-DID to solve
//models        : the level-0 model DIDs, processed in list order
//numExpansions : number of time steps the models and the i-DID are expanded by
//returns       : the solved i-DID domain
Domain * iDIDSolver::Solve(Domain * inputIDID, list<Domain *> models, int numExpansions)
{
	//1. Expand Models and IDID
	cout << "Expand Models" << endl;
	models = ExpandModels(models, numExpansions);

	//Use the domain Expand returns, exactly as ExpandModels does for the
	//models. Discarding it is only correct if Expand works in place, so fall
	//back to the original pointer only when Expand yields nothing.
	ExpandTimeSteps ets;
	Domain * expandedIDID = ets.Expand(inputIDID, numExpansions, ExpandTimeSteps::IDID);
	if (expandedIDID)
	{
		inputIDID = expandedIDID;
	}

	//Each model node gets one state per model
	NodeList modelNodes = GetModelNodes(inputIDID);
	SetModelNodesNumbers(modelNodes, static_cast<int>(models.size()));

	//2.Solve models - compile each model, compute its policies and save it
	cout << "Solve Models" << endl;
	int modelIndex = 0;
	for(list<Domain *>::const_iterator it = models.begin(); it != models.end(); ++it, ++modelIndex)
	{
		Domain * model = *it;
		model->compile();
		model->updatePolicies();

		stringstream ss;
		ss << "\\nets\\Model" << modelIndex << ".net";
		model->saveAsNet(ss.str());
	}

	//3.Add policies from solved models to i-DID
	cout << "Add policies to i-DID" << endl;
	inputIDID = AddModelPoliciesToIDID(inputIDID, models);

	//Written to the same \nets\ directory as every other output of this method
	//rather than a hard-coded drive letter.
	inputIDID->saveAsNet("\\nets\\outputiDID.net");

	//4.Solve i-DID
	cout << "Solve i-DID" << endl;
	inputIDID->compile();
	inputIDID->updatePolicies();

	inputIDID->saveAsNet("\\nets\\iDIDOutput.net");

	return inputIDID;
}


//This method adds the policies from the models to the IDID domain
//this is done by copying the contents of the tables
//Algorithm
//For each model action chance node
//	For Each Model
//		Find corresponding action node
//		Add contents of action node to model node table
//	End For
//End For
//For each model observation chance node
//	For Each Model
//		Find Corresponding observation Node
//		Append Contents of observation node table to model observation node
//	End For
//End For
Domain * iDIDSolver::AddModelPoliciesToIDID(Domain * inputIDID, list<Domain *> models)
{
	//Action nodes
	NodeList modelActionNodes = GetModelActionNodes(inputIDID);
	for(NodeList::const_iterator itma = modelActionNodes.begin(); itma != modelActionNodes.end(); itma++)
	{
		DiscreteChanceNode * node = (DiscreteChanceNode *)*itma;
		NumberList table;

		for(list<Domain *>::const_iterator itmd = models.begin() ; itmd!= models.end(); itmd ++)
		{
			Domain * domain = *itmd;
			string nodeName = node->getName();
			DiscreteDecisionNode * decisionNode = (DiscreteDecisionNode *)domain->getNodeByName(nodeName);

			NumberList policy = decisionNode->getTable()->getData();

			for(NumberList::const_iterator itp = policy.begin(); itp != policy.end(); itp++)
			{
				table.push_back(*itp);
			}
		}	

		node->getTable()->setData(table);
	}

	//Observation Nodes
	NodeList modelObservationNodes = GetModelObservationNodes(inputIDID);
	for(NodeList::const_iterator itmo = modelObservationNodes.begin(); itmo != modelObservationNodes.end(); itmo++)
	{
		DiscreteChanceNode * node = (DiscreteChanceNode *)*itmo;
		NumberList table;

		for(list<Domain *>::const_iterator itmd = models.begin() ; itmd!= models.end(); itmd ++)
		{
			Domain * domain = *itmd;
			string nodeName = node->getName();
			DiscreteChanceNode * observationNode = (DiscreteChanceNode *)domain->getNodeByName(nodeName);

			NumberList cpt = observationNode->getTable()->getData();

			for(NumberList::const_iterator itp = cpt.begin(); itp != cpt.end(); itp++)
			{
				table.push_back(*itp);
			}
		}

		node->getTable()->setData(table);
	}

	return inputIDID;
}

//Gets the name of the decision node equivalent to a model node.
//NOTE(review): not implemented yet - the name is returned unchanged.
//TODO: an earlier draft intended to erase "j" from the name (via
//std::remove/erase); finish once the node naming scheme is settled.
string iDIDSolver::GetModelEquivilentDecisionNodeName(string modelNodeName)
{
	return modelNodeName;
}

//Expands every model DID by expandSteps time steps with ExpandTimeSteps
//(DID mode) and returns the domains Expand produces, in input order.
//The caller's list is taken by value and is not modified.
list<Domain *> iDIDSolver::ExpandModels(list<Domain *> models, int expandSteps)
{
	list<Domain *> result;

	list<Domain *>::const_iterator current = models.begin();
	while (current != models.end())
	{
		//A fresh expander is used for each model
		ExpandTimeSteps expander;
		result.push_back(expander.Expand(*current, expandSteps, ExpandTimeSteps::DID));
		++current;
	}

	return result;
}

//Gives each model node exactly numModels states and labels them
//"Model0", "Model1", ..., "Model<numModels-1>".
void iDIDSolver::SetModelNodesNumbers(NodeList modelNodes, int numModels)
{
	NodeList::const_iterator current = modelNodes.begin();
	while (current != modelNodes.end())
	{
		DiscreteChanceNode * modelNode = static_cast<DiscreteChanceNode *>(*current);
		modelNode->setNumberOfStates(numModels);

		for(int state = 0; state < numModels; ++state)
		{
			stringstream label;
			label << "Model" << state;
			modelNode->setStateLabel(state, label.str());
		}

		++current;
	}
}

namespace
{
	//Returns, in the order getNodes() lists them, every node of domain whose
	//name contains nameFragment as a (case-sensitive) substring.
	//NOTE(review): plain substring matching also picks up any unrelated node
	//whose name happens to contain the fragment.
	NodeList CollectNodesWhoseNameContains(Domain * domain, const string & nameFragment)
	{
		NodeList nodes = domain->getNodes();
		NodeList matches;

		for(NodeList::const_iterator it = nodes.begin(); it != nodes.end(); ++it)
		{
			if ((*it)->getName().find(nameFragment) != std::string::npos)
			{
				matches.push_back(*it);
			}
		}

		return matches;
	}
}

//Extracts the model nodes (names containing "modj") by name
NodeList iDIDSolver::GetModelNodes(Domain * inputDID)
{
	return CollectNodesWhoseNameContains(inputDID, "modj");
}

//Extracts the model action nodes (names containing "aj") by name
NodeList iDIDSolver::GetModelActionNodes(Domain * inputDID)
{
	return CollectNodesWhoseNameContains(inputDID, "aj");
}

//Extracts the model observation nodes (names containing "oj") by name
NodeList iDIDSolver::GetModelObservationNodes(Domain * inputDID)
{
	return CollectNodesWhoseNameContains(inputDID, "oj");
}
